{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68207314-0f2a-4312-81b7-366f7e7adecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install timm==1.0.9\n",
    "# !pip install albumentations==1.4.14\n",
    "# !pip install torcheval==0.0.7\n",
    "# !pip install pandas==2.2.2\n",
    "# !pip install numpy==1.26.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "447003d4-9801-41d8-ae94-c7057ed0e906",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os, time, copy, gc\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import albumentations as A\n",
    "from albumentations.pytorch import ToTensorV2\n",
    "import multiprocessing as mp\n",
    "\n",
    "from torcheval.metrics.functional import binary_auroc, multiclass_auroc\n",
    "\n",
    "import hashlib\n",
    "from joblib import Parallel, delayed\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from PIL import Image\n",
    "import torch.optim as optim\n",
    "\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "sys.path.append('../src')\n",
    "from utils import set_seed, visualize_augmentations_positive, print_trainable_parameters\n",
    "from models import setup_model\n",
    "from training import fetch_scheduler, train_one_epoch, valid_one_epoch\n",
    "from models import ISICModel, ISICModelEdgnet\n",
    "from datasets import ISICDatasetSamplerW, ISICDatasetSampler, ISICDatasetSimple, ISICDatasetSamplerMulticlass\n",
    "from augmentations import get_augmentations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf6f8e2-501d-4adc-9e38-e21251662b10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.\n",
    "has_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\") if has_cuda else torch.device(\"cpu\")\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# GPU details are only queried when CUDA is actually usable.\n",
    "if has_cuda:\n",
    "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
    "    print(f\"Number of GPUs: {torch.cuda.device_count()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64eef977-0185-4821-ab7f-76ce9dbe35f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# File name and output directory for the pretrained checkpoint.\n",
    "str_model_name = \"ema_small_pretrained_medium\"\n",
    "model_dir = \"../models/pretraining\"\n",
    "# Create the directory up front; an already existing directory is fine.\n",
    "os.makedirs(name=model_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca582129-ceae-47fb-81a7-3ac19ea4b15f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run settings shared by the rest of the notebook.\n",
    "CONFIG = dict(\n",
    "    seed=42,  # 42 33\n",
    "    epochs=500,\n",
    "    img_size=336,  # 336,\n",
    "    train_batch_size=32,\n",
    "    valid_batch_size=64,\n",
    "    learning_rate=1e-4,\n",
    "    scheduler='CosineAnnealingLR',\n",
    "    min_lr=1e-6,\n",
    "    T_max=2000,\n",
    "    weight_decay=1e-6,\n",
    "    fold=0,\n",
    "    n_fold=5,\n",
    "    n_accumulate=1,\n",
    "    group_col='patient_id',\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# Backbone identifier handed to setup_model further down.\n",
    "model_name = \"eva02_small_patch14_336.mim_in22k_ft_in1k\"\n",
    "checkpoint_path = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78b7d22e-df9e-41f5-9f59-0b279222ede6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of the original competition data, as a string and as a Path.\n",
    "original_data_path = \"../data/original\"\n",
    "original_root = Path(original_data_path)\n",
    "\n",
    "# Directory for derived artifacts; created if it does not exist yet.\n",
    "data_artifacts = \"../data/artifacts\"\n",
    "os.makedirs(name=data_artifacts, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab3a96ba-f348-4b85-923a-856cb45d0e2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locations of the image archive and the metadata table.\n",
    "TRAIN_HDF5_FILE_PATH = original_root / 'train-image.hdf5'\n",
    "train_path = original_root / 'train-metadata.csv'\n",
    "\n",
    "df_train = pd.read_csv(train_path)\n",
    "# Per-image JPEG path derived from the ISIC identifier.\n",
    "df_train[\"path\"] = '../data/original/train-image/image/' + df_train['isic_id'] + \".jpg\"\n",
    "\n",
    "# Class balance of the binary target column.\n",
    "original_total_cases = len(df_train)\n",
    "original_positive_cases = df_train['target'].sum()\n",
    "original_negative_cases = original_total_cases - original_positive_cases\n",
    "original_positive_ratio = original_positive_cases / original_total_cases\n",
    "\n",
    "print(f\"Number of positive cases: {original_positive_cases}\")\n",
    "print(f\"Number of negative cases: {original_negative_cases}\")\n",
    "print(f\"Ratio of negative to positive cases: {original_negative_cases / original_positive_cases:.2f}:1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b31f5577-07a6-4e3c-9926-d71c41d564ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the augmentation pipelines from CONFIG. Later cells read the\n",
    "# 'train' and 'valid' entries of the returned mapping.\n",
    "data_transforms = get_augmentations(CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07f710dc-4e01-4204-af71-15a94223eb52",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _model_input_steps():\n",
    "    \"\"\"Return fresh resize -> normalize -> tensor steps shared by both pipelines.\"\"\"\n",
    "    return [\n",
    "        A.Resize(CONFIG['img_size'], CONFIG['img_size']),\n",
    "        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "        ToTensorV2(),\n",
    "    ]\n",
    "\n",
    "\n",
    "# Baseline pipeline: only the model-input steps, no random augmentation steps.\n",
    "aug_transform_base = A.Compose(_model_input_steps())\n",
    "\n",
    "# Light augmentation pipeline: rotation, flip and brightness/contrast jitter\n",
    "# followed by the same model-input steps.\n",
    "aug_transform = A.Compose([\n",
    "    A.RandomRotate90(),\n",
    "    A.Flip(),\n",
    "    A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.1, p=0.5),\n",
    "    *_model_input_steps(),\n",
    "])\n",
    "\n",
    "\n",
    "# Dataset used to eyeball the training augmentations.\n",
    "augtest_dataset = ISICDatasetSampler(\n",
    "    transforms=data_transforms['train'],  # look to extreme sometimes but works quite good\n",
    "    do_augmentations=True,\n",
    "    meta_df=df_train,\n",
    "    # transforms=aug_transform_base,\n",
    ")\n",
    "\n",
    "# visualize_augmentations_positive(augtest_dataset, transforms=aug_transform,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e123979-ab12-43ee-99f7-1c5d960e0330",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the external-image metadata and collapse the raw diagnosis labels\n",
    "# into coarse classes. Diagnoses missing from the mapping become NaN.\n",
    "metadata_df = pd.read_csv(\"../images/metadata.csv\")\n",
    "metadata_df['diagnosis_pr'] = metadata_df.diagnosis.map({\n",
    "    'nevus': 'nevus',\n",
    "    'melanoma': 'melanoma',\n",
    "    'basal cell carcinoma': 'bkl',\n",
    "    'seborrheic keratosis': 'bkl',\n",
    "    'solar lentigo': 'bkl',\n",
    "    'lentigo NOS': 'bkl',\n",
    "})\n",
    "# Every benign row that is not in the 'bkl' group is labelled 'nevus'; this\n",
    "# includes benign rows whose diagnosis was not covered by the mapping above.\n",
    "mask = (metadata_df.benign_malignant == 'benign') & (metadata_df.diagnosis_pr != 'bkl')\n",
    "metadata_df.loc[mask, 'diagnosis_pr'] = 'nevus'\n",
    "# Location of each external image on disk.\n",
    "metadata_df[\"path\"] = \"../images/\" + metadata_df['isic_id'] + \".jpg\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d1871c-d8f9-4bf5-880c-498dfcc2dc79",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hash(file_name):\n",
    "    \"\"\"Return the MD5 hex digest of an image's decoded pixel bytes.\n",
    "\n",
    "    Args:\n",
    "        file_name: Path of the image file to open with PIL.\n",
    "\n",
    "    Returns:\n",
    "        str: Hex digest of ``Image.tobytes()`` for the opened image.\n",
    "    \"\"\"\n",
    "    # The context manager closes the underlying file as soon as the pixels\n",
    "    # are read, so hashing many files does not accumulate open handles.\n",
    "    with Image.open(file_name) as image_tmp:\n",
    "        pixel_bytes = image_tmp.tobytes()\n",
    "    return hashlib.md5(pixel_bytes).hexdigest()\n",
    "\n",
    "def get_has_df(df):\n",
    "    \"\"\"Hash every image listed in ``df.path``.\n",
    "\n",
    "    Returns a DataFrame with the original ``path`` column and an\n",
    "    ``image_hash`` column holding the value of ``get_hash`` for each path.\n",
    "    \"\"\"\n",
    "    hashes = [get_hash(image_path) for image_path in df.path]\n",
    "    columns = {\n",
    "        \"path\": df.path,\n",
    "        \"image_hash\": hashes,\n",
    "    }\n",
    "    return pd.DataFrame(columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eadb0ad5-e317-4413-b45e-d5e45ebb7a7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def resize_image(image, resize=512):\n",
    "    \"\"\"Resize ``image`` so that its shorter side equals ``resize``.\n",
    "\n",
    "    The longer side is scaled by the same factor and then rounded down to\n",
    "    a multiple of 8. Returns the result of ``image.resize``.\n",
    "    \"\"\"\n",
    "    width, height = image.size\n",
    "\n",
    "    if height < width:\n",
    "        # Height is the shorter side: pin it and derive the width.\n",
    "        target_size = (int(resize / height * width // 8 * 8), resize)\n",
    "    else:\n",
    "        # Width is the shorter (or equal) side: pin it and derive the height.\n",
    "        target_size = (resize, int(resize / width * height // 8 * 8))\n",
    "\n",
    "    return image.resize(target_size)\n",
    "\n",
    "def resize_images(df, path, size_thr = 512):\n",
    "    \"\"\"Write a PNG copy of every image in ``df`` into the directory ``path``.\n",
    "\n",
    "    Images whose shorter side exceeds ``size_thr`` are downscaled with\n",
    "    ``resize_image`` first; smaller images are saved at their original size.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with ``path`` (source file) and ``isic_id`` columns.\n",
    "        path: Output directory; files are named ``<isic_id>.png``.\n",
    "        size_thr: Shorter-side threshold and resize target in pixels.\n",
    "    \"\"\"\n",
    "    for _, row in df.iterrows():\n",
    "        # The context manager releases the source file handle after each\n",
    "        # image, so long runs do not accumulate open files.\n",
    "        with Image.open(row.path) as img:\n",
    "            if min(img.size) > size_thr:\n",
    "                out_img = resize_image(img, resize=size_thr)\n",
    "            else:\n",
    "                out_img = img\n",
    "            out_img.save(os.path.join(path, row.isic_id + \".png\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1897de8b-80d7-4d16-b637-15d0a334d743",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hash every image in parallel. Chunks are taken with positional slices:\n",
    "# passing a DataFrame straight to np.array_split goes through the\n",
    "# deprecated DataFrame.swapaxes on recent pandas versions.\n",
    "n_workers = mp.cpu_count()\n",
    "chunk_positions = np.array_split(np.arange(len(metadata_df)), n_workers * 2)\n",
    "hash_df = Parallel(n_jobs=n_workers)(\n",
    "    delayed(get_has_df)(metadata_df.iloc[positions]) for positions in chunk_positions)\n",
    "hash_df = pd.concat(hash_df).reset_index(drop=True)\n",
    "\n",
    "metadata_df = metadata_df.merge(\n",
    "    hash_df, how=\"left\", on=[\"path\"]\n",
    ")\n",
    "\n",
    "# Keep exactly one complete row per distinct image. groupby(...).first()\n",
    "# picks the first non-null value column by column, which can stitch a row\n",
    "# together from several duplicates; drop_duplicates keeps a whole row.\n",
    "# Sorting by hash and dropping the hash column keep the row order and the\n",
    "# column set of the previous groupby-based result.\n",
    "metadata_df = (\n",
    "    metadata_df\n",
    "    .sort_values(\"image_hash\", kind=\"stable\")\n",
    "    .drop_duplicates(subset=\"image_hash\", keep=\"first\")\n",
    "    .drop(columns=\"image_hash\")\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "# Integer class ids for the three coarse diagnosis classes.\n",
    "metadata_df[\"diagnosis_pr_target\"] = metadata_df.diagnosis_pr.map({\n",
    "    \"nevus\": 0,\n",
    "    \"bkl\": 1,\n",
    "    \"melanoma\": 2\n",
    "})\n",
    "# Rows without a coarse diagnosis cannot be used as training targets.\n",
    "metadata_df = metadata_df[~metadata_df.diagnosis_pr.isna()].reset_index(drop=True)\n",
    "metadata_df = metadata_df.rename(columns={\n",
    "    'diagnosis_pr_target': 'target'\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b6318cc-037c-44b0-924a-aa66374bf11b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for the resized PNG copies of the external images.\n",
    "resized_path = \"../external_images_resized\"\n",
    "os.makedirs(resized_path, exist_ok=True)\n",
    "\n",
    "# Resize in parallel. Chunks are taken with positional slices: passing a\n",
    "# DataFrame straight to np.array_split goes through the deprecated\n",
    "# DataFrame.swapaxes on recent pandas versions.\n",
    "n_workers = mp.cpu_count()\n",
    "chunk_positions = np.array_split(np.arange(len(metadata_df)), n_workers * 2)\n",
    "Parallel(n_jobs=n_workers)(\n",
    "    delayed(resize_images)(metadata_df.iloc[positions], resized_path)\n",
    "    for positions in chunk_positions);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13176e9f-8a8b-47c9-b21d-272d104f6e5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point every row at its resized PNG and keep only rows whose file exists.\n",
    "metadata_df['path'] = resized_path + '/' + metadata_df['isic_id'] + '.png'\n",
    "file_exists = metadata_df['path'].apply(os.path.exists)\n",
    "metadata_df = metadata_df[file_exists].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a463b700-b4eb-4c15-8e51-3c3536693e62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display (rows, columns) of the metadata table at this point.\n",
    "metadata_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba3525d7-ca20-455a-8b22-6f1b7ce0c106",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffled 80/20 split, stratified on the class target and seeded from CONFIG.\n",
    "train_pretrain_df, val_pretrain_df = train_test_split(\n",
    "    metadata_df,\n",
    "    test_size=0.2,\n",
    "    random_state=CONFIG['seed'],\n",
    "    shuffle=True,\n",
    "    stratify=metadata_df.target,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7fae112-8537-4d6d-bf78-8ec7bbef1faf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets for the 3-class pretraining task.\n",
    "train_dataset = ISICDatasetSamplerMulticlass(\n",
    "    train_pretrain_df,\n",
    "    transforms=data_transforms[\"train\"],\n",
    "    process_target=True,\n",
    "    n_classes=3,\n",
    ")\n",
    "valid_dataset = ISICDatasetSimple(\n",
    "    val_pretrain_df,\n",
    "    transforms=data_transforms[\"valid\"],\n",
    "    process_target=True,\n",
    "    n_classes=3,\n",
    ")\n",
    "\n",
    "# Settings common to both loaders.\n",
    "shared_loader_kwargs = dict(num_workers=10, pin_memory=True)\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=CONFIG['train_batch_size'],\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    **shared_loader_kwargs,\n",
    ")\n",
    "valid_loader = DataLoader(\n",
    "    valid_dataset,\n",
    "    batch_size=CONFIG['valid_batch_size'],\n",
    "    shuffle=False,\n",
    "    **shared_loader_kwargs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fe034c2-91fa-4cfb-9d15-331216581555",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the backbone with a 3-way output, matching the three target ids\n",
    "# (0, 1, 2) assigned to the external metadata, then report its trainable\n",
    "# parameters.\n",
    "model = setup_model(model_name, num_classes=3, device=device)\n",
    "print_trainable_parameters(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62069631-58b1-414b-9427-d58241183318",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over all model parameters; learning rate and weight decay come from CONFIG.\n",
    "optimizer = optim.Adam(\n",
    "    params=model.parameters(),\n",
    "    lr=CONFIG['learning_rate'],\n",
    "    weight_decay=CONFIG['weight_decay'],\n",
    ")\n",
    "# Scheduler chosen by fetch_scheduler from the CONFIG settings.\n",
    "scheduler = fetch_scheduler(optimizer, CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caa80743-b1ff-4fd1-9c3b-5a895abb966c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def criterion_mc(outputs, targets):\n",
    "    \"\"\"Cross-entropy loss between ``outputs`` and ``targets`` with default settings.\"\"\"\n",
    "    loss = nn.functional.cross_entropy(outputs, targets)\n",
    "    return loss\n",
    "\n",
    "def get_nth_test_step(x):\n",
    "    \"\"\"Return the validation interval for epoch ``x``; always 1.\"\"\"\n",
    "    return 1\n",
    "\n",
    "def run_training_pretrain(\n",
    "        train_loader, valid_loader, model, optimizer, scheduler, device, num_epochs, \n",
    "        model_folder=None, model_name=\"\", seed=42, tolerance_max=15, criterion=criterion_mc, test_every_nth_step=get_nth_test_step):\n",
    "    \"\"\"Train ``model`` with periodic validation, early stopping and best-weight restore.\n",
    "\n",
    "    Args:\n",
    "        train_loader: Loader passed to ``train_one_epoch``.\n",
    "        valid_loader: Loader passed to ``valid_one_epoch``.\n",
    "        model: Model to train; the best weights are loaded back into it.\n",
    "        optimizer: Optimizer passed to the train/valid helpers.\n",
    "        scheduler: Scheduler passed to ``train_one_epoch``.\n",
    "        device: Device passed to the train/valid helpers.\n",
    "        num_epochs: Maximum number of epochs.\n",
    "        model_folder: If not None, the best weights are also saved to\n",
    "            ``os.path.join(model_folder, model_name)`` on every improvement.\n",
    "        model_name: File name used together with ``model_folder``.\n",
    "        seed: Seed handed to ``set_seed`` before training starts.\n",
    "        tolerance_max: Training stops once more than this many validations\n",
    "            in a row failed to reach the best validation AUROC.\n",
    "        criterion: Loss callable ``(outputs, targets) -> loss``.\n",
    "        test_every_nth_step: Callable ``epoch -> n``; validation runs on\n",
    "            epochs where ``epoch % n == 0``.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(model, history)`` where ``history`` maps metric names to\n",
    "        lists with one entry per validated epoch.\n",
    "    \"\"\"\n",
    "    set_seed(seed)\n",
    "    start = time.time()\n",
    "    best_model_wts = copy.deepcopy(model.state_dict())\n",
    "    best_epoch_score = -np.inf\n",
    "    history = defaultdict(list)\n",
    "    tolerance = 0\n",
    "    for epoch in range(1, num_epochs + 1):\n",
    "        if tolerance > tolerance_max:\n",
    "            break\n",
    "        # Use the schedule passed by the caller rather than the module-level\n",
    "        # default, so a custom test_every_nth_step actually takes effect.\n",
    "        validate_every = test_every_nth_step(epoch)\n",
    "        gc.collect()\n",
    "        train_epoch_loss, train_epoch_auroc = train_one_epoch(\n",
    "            model, \n",
    "            optimizer, \n",
    "            scheduler, \n",
    "            dataloader=train_loader, \n",
    "            device=device,\n",
    "            CONFIG=CONFIG,\n",
    "            epoch=epoch, \n",
    "            criterion=criterion,\n",
    "            metric_function=multiclass_auroc, \n",
    "            num_classes=3)\n",
    "\n",
    "        if epoch % validate_every == 0:\n",
    "            val_epoch_loss, val_epoch_auroc, val_epoch_custom_metric = valid_one_epoch(\n",
    "                model, \n",
    "                valid_loader, \n",
    "                device=device, \n",
    "                epoch=epoch, \n",
    "                optimizer=optimizer, \n",
    "                criterion=criterion, \n",
    "                use_custom_score=False,\n",
    "                metric_function=multiclass_auroc, \n",
    "                num_classes=3)\n",
    "\n",
    "            history['Train Loss'].append(train_epoch_loss)\n",
    "            history['Valid Loss'].append(val_epoch_loss)\n",
    "            history['Train AUROC'].append(train_epoch_auroc)\n",
    "            history['Valid AUROC'].append(val_epoch_auroc)\n",
    "            history['Valid Kaggle metric'].append(val_epoch_custom_metric)\n",
    "            # Read the learning rate currently held by the optimizer.\n",
    "            # scheduler.get_lr() is only meant to be called from inside\n",
    "            # scheduler.step() and is not the applied value elsewhere.\n",
    "            history['lr'].append(optimizer.param_groups[0]['lr'])\n",
    "\n",
    "            if best_epoch_score <= val_epoch_auroc:\n",
    "                # New best (ties count): reset patience and snapshot weights.\n",
    "                tolerance = 0\n",
    "                print(f\"Validation AUROC Improved ({best_epoch_score} ---> {val_epoch_auroc})\")\n",
    "                best_epoch_score = val_epoch_auroc\n",
    "                best_model_wts = copy.deepcopy(model.state_dict())\n",
    "                if model_folder is not None:\n",
    "                    torch.save(model.state_dict(), os.path.join(model_folder, model_name))\n",
    "            else:\n",
    "                tolerance += 1\n",
    "\n",
    "        print()\n",
    "\n",
    "    end = time.time()\n",
    "    time_elapsed = end - start\n",
    "    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(\n",
    "        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))\n",
    "    print(\"Best AUROC: {:.4f}\".format(best_epoch_score))    \n",
    "    # Leave the model at its best validated weights, not the last epoch's.\n",
    "    model.load_state_dict(best_model_wts)\n",
    "    return model, history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ab81fca-1d30-4d09-af67-f9140b87cb1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run pretraining. model_folder keeps its default of None, so no\n",
    "# checkpoint file is written from inside the training loop.\n",
    "model, history = run_training_pretrain(\n",
    "    train_loader=train_loader,\n",
    "    valid_loader=valid_loader,\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    scheduler=scheduler,\n",
    "    device=CONFIG['device'],\n",
    "    num_epochs=CONFIG['epochs'],\n",
    "    criterion=criterion_mc,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95a867b1-bdf0-41f9-84c3-2090d48a3107",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the weights of the model returned by training under\n",
    "# model_dir / str_model_name.\n",
    "torch.save(model.state_dict(), os.path.join(model_dir, str_model_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56fb7b9f-b1f3-418d-a6bf-360f1c2dbdb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset and loader over the full competition table, using the 'valid'\n",
    "# transform pipeline and a fixed (unshuffled) row order.\n",
    "df_train_dataset = ISICDatasetSimple(\n",
    "    df_train,\n",
    "    transforms=data_transforms[\"valid\"],\n",
    "    process_target=True,\n",
    "    n_classes=3,\n",
    ")\n",
    "df_train_loader = DataLoader(\n",
    "    df_train_dataset,\n",
    "    batch_size=CONFIG['valid_batch_size'],\n",
    "    shuffle=False,\n",
    "    num_workers=5,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c835d5b7-1bcf-49c6-b91f-8efe5ce3f0d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def criterion(outputs, targets):\n",
    "    \"\"\"Binary cross-entropy between probabilities ``outputs`` and ``targets``.\"\"\"\n",
    "    loss = nn.functional.binary_cross_entropy(outputs, targets)\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e97717cc-efe8-4cf7-9535-036a9d510f11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the pretrained model over the competition table and collect its\n",
    "# per-class predictions.\n",
    "(val_epoch_loss, val_epoch_auroc, val_epoch_custom_metric,\n",
    " tmp_predictions_all, tmp_targets_all) = valid_one_epoch(\n",
    "    model,\n",
    "    df_train_loader,\n",
    "    device=CONFIG['device'],\n",
    "    epoch=1,\n",
    "    optimizer=optimizer,\n",
    "    criterion=criterion,\n",
    "    use_custom_score=False,\n",
    "    metric_function=multiclass_auroc,\n",
    "    num_classes=3,\n",
    "    return_preds=True,\n",
    ")\n",
    "\n",
    "# One feature column per predicted class: old_set_0, old_set_1, old_set_2.\n",
    "for class_idx in range(3):\n",
    "    df_train[f'old_set_{class_idx}'] = tmp_predictions_all[:, class_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "343ee890-fab3-4bed-929f-90e4b912c8b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the per-class predictions keyed by isic_id into the artifacts\n",
    "# directory configured (and created) earlier, instead of repeating its\n",
    "# location as a hard-coded literal.\n",
    "forecast_path = os.path.join(data_artifacts, 'old_data_model_forecast_large.parquet')\n",
    "df_train[['isic_id', 'old_set_0', 'old_set_1', 'old_set_2']].to_parquet(forecast_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8d6d57c-b5c9-4ed5-bd23-0a699faaf3dd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
